import mxnet as mx
import numpy as np
import argparse
import json
from mxnet.contrib import onnx as onnx_mxnet
from onnx import checker
import onnx
import logging
logging.basicConfig(level=logging.INFO)

def mxnet2onnx(mx_json, mx_params, onnx_file, input_shape=(1, 3, 512, 512)):
    """Export an MXNet symbol/params pair to ONNX and validate the result.

    Args:
        mx_json: Path to the MXNet symbol JSON file.
        mx_params: Path to the MXNet ``.params`` weights file.
        onnx_file: Destination path for the exported ONNX model.
        input_shape: Shape of the single network input (presumably NCHW --
            confirm against the model). Defaults to ``(1, 3, 512, 512)``,
            the value this function previously hard-coded.

    Returns:
        The converted model path as returned by ``onnx_mxnet.export_model``.

    Raises:
        Whatever ``onnx.checker.check_graph`` raises (``ValidationError``)
        when the exported graph is invalid.
    """
    print("mx_json:%s, mx_params:%s, onnx_file:%s"%(mx_json, mx_params, onnx_file))

    # export_model takes one shape per network input; this converter
    # assumes a single float32 input.
    converted_model_path = onnx_mxnet.export_model(
        mx_json, mx_params, [tuple(input_shape)], np.float32, onnx_file)
    print("converted model path:", converted_model_path)

    # Re-load the written file so we validate exactly what ended up on disk.
    model_proto = onnx.load_model(converted_model_path)
    # Raises if the converted ONNX graph is malformed.
    checker.check_graph(model_proto.graph)
    return converted_model_path

def make_parser():
    """Build the command-line parser for the converter.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--config``. When the
        option is omitted, it falls back to ``./net_inst.txt``.
    """
    parser = argparse.ArgumentParser()
    # Not ``required=True``: argparse ignores ``default`` for required
    # options, which made the fallback path below unreachable.
    parser.add_argument('--config', default="./net_inst.txt", type=str, help='config file for inference')
    return parser

# Config keys the converter needs, in matching priority order.
_CONFIG_KEYS = ("mxnet_json", "mxnet_params", "inst_file")


def parse_config(text):
    """Extract converter settings from ``key=value`` config text.

    For each line containing ``=``, the part before the first ``=`` is the
    key and the rest (whitespace-stripped) is the value. The value is stored
    under the first of ``mxnet_json``, ``mxnet_params``, ``inst_file`` whose
    name appears in the key. Lines without ``=`` or with an unrecognised key
    are ignored. Later lines override earlier ones.

    Args:
        text: Full contents of the config file.

    Returns:
        A dict mapping each recognised key to its value. Keys absent from
        the text are absent from the dict.
    """
    settings = {}
    for line in text.splitlines():
        key, sep, value = line.partition("=")
        if not sep:
            continue
        # Match against the key only, so a value such as "mxnet_json_dir"
        # cannot be misrouted to another setting.
        for name in _CONFIG_KEYS:
            if name in key:
                settings[name] = value.strip()
                break
    return settings


if __name__ == "__main__":
    myparser = make_parser()
    args = myparser.parse_args()
    with open(args.config) as config_fh:
        settings = parse_config(config_fh.read())

    missing = [name for name in _CONFIG_KEYS if name not in settings]
    if missing:
        # Report a clear CLI error instead of a NameError further down.
        myparser.error("config file %s is missing: %s" % (args.config, ", ".join(missing)))

    mxnet2onnx(settings["mxnet_json"], settings["mxnet_params"], settings["inst_file"] + ".onnx")
